[WIP] Implements RNNT+MMI #1030 (Draft)

pkufool wants to merge 9 commits into master
Conversation

@pkufool (Collaborator) commented Aug 9, 2022

It runs normally in my self-constructed test cases, though it is not fully tested yet.

The sampled paths:

sampled_paths = torch.tensor([ [ [ 3, 5, 0, 4, 6, 0, 2, 1 ],
                                 [ 2, 0, 5, 4, 0, 6, 1, 2 ],
                                 [ 3, 5, 2, 0, 0, 1, 6, 4 ]],
                               [ [ 7, 0, 4, 0, 6, 0, 3, 0 ],
                                 [ 0, 7, 3, 0, 2, 0, 4, 5 ],
                                 [ 7, 0, 3, 4, 0, 1, 2, 0 ]]], dtype=torch.int32)
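The tensor above has shape (2, 3, 8): 2 sequences, 3 sampled paths per sequence, 8 frames, with symbol 0 as blank. As a hedged illustration of how such paths could be drawn (the actual PR samples from model posteriors; the uniform distribution here is purely for shape-checking):

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes matching the example above: 2 sequences, 3 paths
# per sequence, 8 frames, vocabulary of 8 symbols (0 is blank).
num_seqs, num_paths, num_frames, vocab = 2, 3, 8, 8

# Per-frame posterior over symbols (uniform here, just for illustration;
# in the PR these would come from the model).
probs = torch.full((num_seqs, num_paths, num_frames, vocab), 1.0 / vocab)

# Draw one symbol per (sequence, path, frame).
sampled = torch.multinomial(
    probs.reshape(-1, vocab), num_samples=1).reshape(
        num_seqs, num_paths, num_frames).to(torch.int32)

print(sampled.shape)  # torch.Size([2, 3, 8])
```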

The corresponding lattice:
[lattice images omitted]

Note: there is an arc from state 2 to state 17 in the second lattice because the last symbol of the second path of the second sequence is sampled at frame 1; it simulates reaching the final frame.

@pkufool (Collaborator, Author) left a comment:
@danpovey Do you have any good ideas for testing this function? I can only think of constructing simple test cases.

repeat_num = us_row_splits1_data[us_idx0 + 1] -
             us_row_splits1_data[us_idx0];

arc.score = -logf(1 - powf(1 - sampling_prob, repeat_num));
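My reading of this formula: if a symbol is drawn `repeat_num` times, each draw keeping it with probability `sampling_prob`, then `1 - (1 - p) ** n` is the probability that at least one draw selects it, and the arc score is its negative log. A minimal sketch of the same arithmetic (function name is hypothetical):

```python
import math

def arc_score(sampling_prob: float, repeat_num: int) -> float:
    # Probability that at least one of `repeat_num` independent draws,
    # each succeeding with probability `sampling_prob`, succeeds:
    #   1 - (1 - p) ** n
    # The arc score is its negative log, mirroring the C++ line
    #   arc.score = -logf(1 - powf(1 - sampling_prob, repeat_num));
    return -math.log(1.0 - (1.0 - sampling_prob) ** repeat_num)

print(arc_score(0.5, 1))  # -log(0.5)  ~= 0.6931
print(arc_score(0.5, 2))  # -log(0.75) ~= 0.2877
```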
@pkufool (Collaborator, Author) commented:
I only include the "predictor" head output in the C++ part; the other two scores (i.e. the hybrid output and lm_output) will be added in the Python part, which makes it easier to enable autograd for the hybrid output.

a_value = getattr(lattice, "scores")
# Enable autograd for path_scores
b_value = index_select(path_scores.flatten(), arc_map)
value = a_value + b_value
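The pattern above gathers per-path scores onto lattice arcs via `arc_map` and adds them to the lattice's own scores, so gradients flow back into `path_scores`. A hedged stand-in sketch using `torch.index_select` (k2's `index_select` additionally handles -1 entries in `arc_map`; all shapes here are made up):

```python
import torch

# Hypothetical sizes: a lattice with 6 arcs whose scores are augmented
# from a flattened `path_scores` tensor via `arc_map`, mimicking
#   b_value = index_select(path_scores.flatten(), arc_map)
lattice_scores = torch.randn(6)                     # `a_value`
path_scores = torch.randn(4, 2, requires_grad=True)
arc_map = torch.tensor([0, 3, 1, 7, 5, 2])

b_value = torch.index_select(path_scores.flatten(), 0, arc_map)
value = lattice_scores + b_value

# Gradients flow back into path_scores through the gather.
value.sum().backward()
print(path_scores.grad is not None)  # True
```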
@pkufool (Collaborator, Author) commented:
path_scores here will contain hybrid_output and the detached lm_output. I include path_scores here and enable autograd for path_scores.

A Collaborator commented:

Yes, OK. Right, we treat those as differentiable, but the negated sampling_prob is treated as just a constant.

# index == 0 means the sampled symbol is blank
t_mask = index == 0
# t_index = torch.where(t_mask, t_index + 1, t_index)
t_index = t_index + 1
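The live line always advances the frame index (the pruned-RNN-T regime, where each step emits exactly one symbol per frame), while the commented-out `torch.where` variant would advance it only on blank, as in regular RNN-T. A small sketch contrasting the two, with made-up data:

```python
import torch

# Sampled symbol ids for 4 paths at the current step; 0 means blank
# (hypothetical data).
index = torch.tensor([0, 3, 0, 5])
t_index = torch.tensor([2, 2, 2, 2])

t_mask = index == 0  # True where the sampled symbol is blank

# Pruned-RNN-T style: every step consumes a frame, so always advance.
t_index_always = t_index + 1

# Regular-RNN-T alternative (the commented-out line): advance the frame
# only when blank is emitted, letting several symbols share one frame.
t_index_on_blank = torch.where(t_mask, t_index + 1, t_index)

print(t_index_always.tolist())    # [3, 3, 3, 3]
print(t_index_on_blank.tolist())  # [3, 2, 3, 2]
```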
@pkufool (Collaborator, Author) commented:
If we use regular RNN-T, it is possible to generate too many symbols for a specific frame, which might produce a lattice containing cycles; that is not expected. I am not sure whether we will encounter such an issue at the very beginning of training.

A Collaborator commented:

Hm, a valid point. Yes, computing forward-backward scores would not work correctly if there are cycles. One possibility would be to augment the state with a sub-frame index, i.e. instead of (ctx, t) it becomes (ctx, t, sub_t) with sub_t = 0, 1, 2, .... That would prevent cycles, although it might also prevent a small number of paths from recombining that otherwise could.
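A toy sketch of this suggestion, under my own assumptions (all names hypothetical, and the 2-symbol context is just for illustration): keying states by (ctx, t, sub_t) means every transition strictly increases (t, sub_t) lexicographically, so no cycle can form.

```python
# Cycle-avoiding state augmentation: key each lattice state by
# (ctx, t, sub_t) instead of (ctx, t). `sub_t` counts non-blank symbols
# emitted on frame t; a blank moves to the next frame and resets it.

def next_state(state, symbol, blank=0):
    ctx, t, sub_t = state
    if symbol == blank:
        return (ctx, t + 1, 0)           # consume frame t, reset sub_t
    new_ctx = (ctx + (symbol,))[-2:]     # toy 2-symbol context window
    return (new_ctx, t, sub_t + 1)       # stay on frame t, bump sub_t

s = (tuple(), 0, 0)
for sym in [3, 5, 0, 4]:
    s = next_state(s, sym)
print(s)  # ((5, 4), 1, 1)
```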

@pkufool pkufool marked this pull request as draft August 17, 2022 02:43